{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d093f55-b308-47d5-b255-a47f1df1b2df",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Add the project's files to the python path\n",
    "# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script\n",
    "file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook\n",
    "sys.path.append(file_path)\n",
    "\n",
    "import hydra\n",
    "from src.utils import init_config\n",
    "import torch\n",
    "from src.transforms import *\n",
    "from src.utils.widgets import *\n",
    "from src.data import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ded143fa-64ec-4bb7-8122-a51a5f6e96dd",
   "metadata": {},
   "source": [
    "## Select your device, experiment, split, and pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6714791-04da-489d-aaab-021954242ac2",
   "metadata": {},
   "outputs": [],
   "source": [
    "device_widget = make_device_widget()\n",
    "task_widget, expe_widget = make_experiment_widgets()\n",
    "split_widget = make_split_widget()\n",
    "ckpt_widget = make_checkpoint_file_search_widget()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d10a3df-dc50-48e3-b297-637cfcd8fffa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarizing selected task, experiment, split, and checkpoint\n",
    "print(\"You chose:\")\n",
    "print(f\"  - device={device_widget.value}\")\n",
    "print(f\"  - task={task_widget.value}\")\n",
    "print(f\"  - split={split_widget.value}\")\n",
    "print(f\"  - experiment={expe_widget.value}\")\n",
    "print(f\"  - ckpt={ckpt_widget.value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3496cc8f-7fa4-481b-b4f6-5dc48993494a",
   "metadata": {},
   "source": [
    "## Parsing the config files\n",
    "Hydra and OmegaConf are used to parse the `yaml` config files.\n",
    "\n",
    "❗Make sure you have selected a **ckpt file relevant to your experiment** in the previous section.\n",
    "You can use our pretrained models for this, or your own checkpoints if you have already trained a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d526f72f-4585-4ff9-9ce1-a7c1885e572f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the configs using hydra\n",
    "cfg = init_config(overrides=[\n",
    "    f\"experiment={task_widget.value}/{expe_widget.value}\",\n",
    "    f\"ckpt_path={ckpt_widget.value}\",\n",
    "    \"datamodule.load_full_res_idx=True\"  # only needed for full-resolution predictions\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e4e98af-a7eb-4b40-b8e3-8e64d70498f5",
   "metadata": {},
   "source": [
    "## Datamodule and model instantiation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06f3d5d-d42d-449c-b494-f5f5c71be4e4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Instantiate the datamodule\n",
    "datamodule = hydra.utils.instantiate(cfg.datamodule)\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup()\n",
    "\n",
    "# Pick among train, val, and test datasets. Note that the train\n",
    "# dataset produces augmented spherical samples of large scenes, while\n",
    "# the val and test datasets load entire tiles at once\n",
    "if split_widget.value == 'train':\n",
    "    dataset = datamodule.train_dataset\n",
    "elif split_widget.value == 'val':\n",
    "    dataset = datamodule.val_dataset\n",
    "elif split_widget.value == 'test':\n",
    "    dataset = datamodule.test_dataset\n",
    "else:\n",
    "    raise ValueError(f\"Unknown split '{split_widget.value}'\")\n",
    "\n",
    "# Print a summary of the dataset's classes\n",
    "dataset.print_classes()\n",
    "\n",
    "# Instantiate the model\n",
    "model = hydra.utils.instantiate(cfg.model)\n",
    "\n",
    "# Load pretrained weights from a checkpoint file\n",
    "if ckpt_widget.value is not None:\n",
    "    model = model._load_from_checkpoint(cfg.ckpt_path)\n",
    "\n",
    "# Move model to selected device\n",
    "model = model.eval().to(device_widget.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f95ac9f4-3266-47b9-909d-b189dabeec2d",
   "metadata": {},
   "source": [
    "## Hierarchical partition loading and inference\n",
    "SPT can process very large scenes at once.\n",
    "The input depends on the dataset split you selected: with `val` or `test`, inference runs on an entire million-point tile, whereas with `train` it runs on a spherical sample of that tile."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4de618d2-d809-4c6a-b38c-7a97f66b3615",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# For the sake of visualization, we require that SPT does not \n",
    "# remove input Data attributes after moving them to Data.x, so we may \n",
    "# visualize them\n",
    "model.net.store_features = True\n",
    "\n",
    "# Load the first dataset item. This will return the hierarchical \n",
    "# partition of an entire tile, as a NAG object \n",
    "nag = dataset[0]\n",
    "\n",
    "# Apply on-device transforms on the NAG object. For the train dataset, \n",
    "# this will select a spherical sample of the larger tile and apply some\n",
    "# data augmentations. For the validation and test datasets, this will\n",
    "# prepare an entire tile for inference\n",
    "nag = dataset.on_device_transform(nag.to(device_widget.value))\n",
    "\n",
    "# Inference, returns a task-specific output object carrying predictions\n",
    "with torch.no_grad():\n",
    "    output = model(nag)\n",
    "\n",
    "# Compute the level-0 (voxel-wise) semantic segmentation predictions \n",
    "# based on the predictions on level-1 superpoints and save those for \n",
    "# visualization in the level-0 Data under the 'semantic_pred' attribute\n",
    "nag[0].semantic_pred = output.voxel_semantic_pred(super_index=nag[0].super_index)\n",
    "\n",
    "# Similarly, compute the level-0 panoptic segmentation predictions, if \n",
    "# relevant\n",
    "if task_widget.value == 'panoptic':\n",
    "    vox_y, vox_index, vox_obj_pred = output.voxel_panoptic_pred(super_index=nag[0].super_index)\n",
    "    nag[0].obj_pred = vox_obj_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00b0e64f-4e1b-41cb-983c-993c5dc7fba4",
   "metadata": {},
   "source": [
    "By design, our model only needs to produce predictions for the $P_1$ superpoints for training. \n",
    "This conveniently saves compute and memory at training and evaluation time.\n",
    "\n",
    "At inference time, however, we often **need the predictions on the $P_0$ voxel level or on the full-resolution input point cloud**.\n",
    "To this end, we provide helper functions to recover voxel-wise and full-resolution predictions.\n",
    "In the previous cell, for instance, `voxel_semantic_pred()` and `voxel_panoptic_pred()` were used for computing voxel-wise predictions and attaching them to our `NAG` object.\n",
    "\n",
    "In the following cell, we show how to efficiently recover the full-resolution predictions with `full_res_semantic_pred()` and `full_res_panoptic_pred()` (requires `datamodule.load_full_res_idx=True` in the config)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67965117-35c2-46c1-836f-21e5ce3f549e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the full-resolution semantic prediction. These labels are ordered \n",
    "# with respect to the full-resolution data points in the corresponding raw \n",
    "# input file. Note that we do not provide the pipeline for recovering the \n",
    "# corresponding full-resolution positions, colors, etc. \n",
    "raw_semseg_y = output.full_res_semantic_pred(\n",
    "    super_index_level0_to_level1=nag[0].super_index,\n",
    "    sub_level0_to_raw=nag[0].sub)\n",
    "\n",
    "# Similarly, we can compute the full-resolution panoptic prediction. \n",
    "# The returned outputs are (in order) the semantic prediction, the\n",
    "# predicted instance index, and the InstanceData object holding this\n",
    "# information in another format\n",
    "if task_widget.value == 'panoptic':\n",
    "    raw_pano_y, raw_index, raw_obj_pred = output.full_res_panoptic_pred(\n",
    "        super_index_level0_to_level1=nag[0].super_index, \n",
    "        sub_level0_to_raw=nag[0].sub)"
   ]
  },
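  {
   "cell_type": "markdown",
   "id": "3f1a9c52-7b4e-4d0a-9e61-2c8d5a7f0b13",
   "metadata": {},
   "source": [
    "To build intuition for what these helpers compute, the next cell gives a toy sketch using plain `torch` tensors.\n",
    "It assumes that predictions are propagated down the hierarchy by simple index lookups: each voxel inherits the label of its parent superpoint, and each raw point inherits the label of its voxel.\n",
    "All names and values (`toy_sp_pred`, `toy_super_index`, `toy_point_to_voxel`) are hypothetical, and this is not the actual implementation of `full_res_semantic_pred()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c2d4e61-0a9b-4f37-b5d2-6e1f7a3c9d04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Toy example only: hypothetical names and values, not the library's code\n",
    "\n",
    "# Predicted class label for each of 3 level-1 superpoints\n",
    "toy_sp_pred = torch.tensor([2, 0, 1])\n",
    "\n",
    "# Parent superpoint index for each of 5 level-0 voxels\n",
    "toy_super_index = torch.tensor([0, 0, 1, 2, 2])\n",
    "\n",
    "# Each voxel inherits the label of its parent superpoint\n",
    "toy_voxel_pred = toy_sp_pred[toy_super_index]\n",
    "\n",
    "# Assumed mapping from each of 9 raw points to the voxel containing it\n",
    "toy_point_to_voxel = torch.tensor([0, 0, 1, 2, 2, 3, 4, 4, 4])\n",
    "\n",
    "# Each raw point inherits the label of its voxel\n",
    "toy_point_pred = toy_voxel_pred[toy_point_to_voxel]\n",
    "\n",
    "print(toy_voxel_pred.tolist())  # [2, 2, 0, 1, 1]\n",
    "print(toy_point_pred.tolist())  # [2, 2, 2, 0, 0, 1, 1, 1, 1]"
   ]
  },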
  {
   "cell_type": "markdown",
   "id": "9c1ab511-dc3a-4750-bde8-5c6fcae7a516",
   "metadata": {},
   "source": [
    "## Visualizing an entire tile\n",
    "SPT can process very large scenes at once. Let's visualize the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae5bec5-01d2-41e3-90ec-244a61cac5d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the hierarchical partition\n",
    "nag.show( \n",
    "    class_names=dataset.class_names,\n",
    "    class_colors=dataset.class_colors,\n",
    "    stuff_classes=dataset.stuff_classes,\n",
    "    num_classes=dataset.num_classes,\n",
    "    max_points=100000\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c986a13-ee00-4538-aa7d-26c5d4077dc5",
   "metadata": {},
   "source": [
    "However, for memory reasons, the visualization cannot display all points. Let's have a look at a smaller area."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15a653dd-291c-43c6-8c83-78ebff2d00b9",
   "metadata": {},
   "source": [
    "## Selecting a portion of the hierarchical partition\n",
    "The NAG structure can be subselected using `nag.select()`.\n",
    "\n",
    "This function expects an `int` specifying the partition level from which we should select, along with an index or a mask in the form of a `list`, `numpy.ndarray`, `torch.Tensor`, or `slice`.\n",
    "This index/mask describes which nodes to select at the specified level.\n",
    "\n",
    "The output NAG will only contain children, parents and edges of the selected nodes.\n",
    "\n",
    "Specifying `radius` and `center` for the `show()` function will make use of this `nag.select()` method internally."
   ]
  },
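  {
   "cell_type": "markdown",
   "id": "5a0e7d93-1c6b-4b82-8f4d-9b3e2a6c1f57",
   "metadata": {},
   "source": [
    "As an illustration of the two accepted forms, the toy cell below builds a boolean mask and the equivalent integer index for the points lying within a given distance of a center.\n",
    "This is only one plausible way to construct such a selection, not necessarily what `show()` does internally, and all `toy_*` names and values are hypothetical.\n",
    "Based on the description above, either form could then be passed to `nag.select()` together with the partition level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4b6f018-2d7a-4c95-a3e8-7f5c0d1b9a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Toy example only: hypothetical positions, center, and radius\n",
    "toy_pos = torch.tensor([\n",
    "    [0.0, 0.0, 0.0],\n",
    "    [1.0, 0.0, 0.0],\n",
    "    [0.0, 3.0, 0.0],\n",
    "    [5.0, 5.0, 5.0]])\n",
    "toy_center = torch.tensor([[0.0, 0.0, 0.0]])\n",
    "toy_radius = 2.0\n",
    "\n",
    "# Boolean mask: True for points closer to the center than the radius\n",
    "toy_mask = (toy_pos - toy_center).norm(dim=1) < toy_radius\n",
    "\n",
    "# Equivalent integer index of the selected points\n",
    "toy_idx = torch.where(toy_mask)[0]\n",
    "\n",
    "print(toy_mask.tolist())  # [True, True, False, False]\n",
    "print(toy_idx.tolist())   # [0, 1]"
   ]
  },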
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca756c39-0d2d-495e-a7f1-b44ddbdaa346",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predefined radius and center locations for each dataset\n",
    "# Feel free to modify these values\n",
    "center = nag[0].pos.mean(dim=0).view(1, -1)\n",
    "if 'dales' in expe_widget.value:\n",
    "    radius = 10\n",
    "elif 'kitti360' in expe_widget.value:\n",
    "    radius = 10\n",
    "elif 'scannet' in expe_widget.value:\n",
    "    radius = 10\n",
    "elif 's3dis' in expe_widget.value:\n",
    "    radius = 3\n",
    "else:\n",
    "    radius = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ef97ce1-527d-418d-b273-839c721ede6e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the sample\n",
    "nag.show(\n",
    "    radius=radius,\n",
    "    center=center,\n",
    "    class_names=dataset.class_names,\n",
    "    class_colors=dataset.class_colors,\n",
    "    stuff_classes=dataset.stuff_classes,\n",
    "    num_classes=dataset.num_classes,\n",
    "    max_points=100000\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca3eeadf-de22-47d7-b7e8-32e036738415",
   "metadata": {},
   "source": [
    "## Visualizing random samples centered on a class of interest\n",
    "You may be interested in seeing random samples of a given class. To this end, you can simply use the `BaseDataset.show_examples()` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e9f4709-3f5d-4743-b7dd-a30ce21ed56c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display some samples of a dataset, centered on a label of interest. \n",
    "# You may specify the class of interest as an int label or by the actual \n",
    "# class name \n",
    "datamodule.train_dataset.show_examples(7, radius=radius, max_examples=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abf42a3e-2ba8-46f7-9eda-a9a22bc05674",
   "metadata": {},
   "source": [
    "## Visualizing the superpoint graphs\n",
    "Let's take a closer look at the graph connecting superpoints by setting `centroids=True` and `h_edge=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b9f9ebf-6549-4c7b-84dd-ab361d4f201a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the sample with superpoint centroids and horizontal edges\n",
    "nag.show(\n",
    "    radius=radius,\n",
    "    center=center,\n",
    "    class_names=dataset.class_names,\n",
    "    class_colors=dataset.class_colors,\n",
    "    stuff_classes=dataset.stuff_classes,\n",
    "    num_classes=dataset.num_classes,\n",
    "    max_points=100000, \n",
    "    centroids=True, \n",
    "    h_edge=True, \n",
    "    h_edge_width=2\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08758db4-d865-4c84-b9f8-3adca6d3fdef",
   "metadata": {},
   "source": [
    "## Side-by-side visualization mode\n",
    "By setting `gap` to a chosen 3D offset, we can visualize all partition levels at once. In addition, setting `v_edge=True` displays the vertical edges connecting superpoints with their children."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd261bc-486e-4c58-acaa-9a37ed24444f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize all partition levels side by side, with vertical edges\n",
    "nag.show(\n",
    "    figsize=1000,\n",
    "    radius=radius,\n",
    "    center=center,\n",
    "    class_names=dataset.class_names,\n",
    "    class_colors=dataset.class_colors,\n",
    "    stuff_classes=dataset.stuff_classes,\n",
    "    num_classes=dataset.num_classes,\n",
    "    max_points=100000, \n",
    "    centroids=True, \n",
    "    v_edge=True, \n",
    "    v_edge_width=2, \n",
    "    gap=[0, 0, 4]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "def3d228-29fd-466a-b27d-e45b4de77429",
   "metadata": {},
   "source": [
    "## Exporting your visualization to HTML\n",
    "You can export your interactive visualization to HTML.\n",
    "You can then share the resulting file, which can be opened in any web browser with an internet connection.\n",
    "\n",
    "To export a visualization, simply specify a `path` to which the file should be saved.\n",
    "Additionally, you may set a `title` to be displayed in your HTML."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be0113c8-b4a2-4033-b49f-49ff6cbbe3da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Visualize the sample and export it to an HTML file\n",
    "nag.show(\n",
    "    figsize=1600,\n",
    "    radius=radius,\n",
    "    center=center,\n",
    "    class_names=dataset.class_names,\n",
    "    class_colors=dataset.class_colors,\n",
    "    stuff_classes=dataset.stuff_classes,\n",
    "    num_classes=dataset.num_classes,\n",
    "    max_points=100000,\n",
    "    title=\"My Interactive Visualization Partition\", \n",
    "    path=\"my_interactive_visualization.html\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69df8eb7-d7f5-4e98-86ee-8dc231939b2a",
   "metadata": {},
   "source": [
    "## Going further with visualization\n",
    "See the commented code in `src.visualization` for more visualization options."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spt_test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
